#!/usr/bin/env python3

import socket
import time
import sys
import os

# TCP port the background helper connects to on the test server.
PORT = 8880
# Timeout, in seconds, passed to socket.create_connection() for the test socket.
TIMEOUT = 0.1
# AF_UNIX address of the control socket shared by the hook invocations.
# The leading NUL byte selects the Linux abstract socket namespace, so no
# filesystem entry is created for it.
INTERNAL_SERVER = "\0internal_server"
# File polled by wait_server_addr(); it is expected to contain "ipv4" or
# "ipv6" (presumably written by the test server once it is listening --
# the writer is not part of this script).
SYNCFILE = "zdtm/static/socket_lock.sync"

def wait_server_addr(retries=3, delay=1):
    """Wait for the sync file and return the loopback address it selects.

    The file at SYNCFILE is read (first four characters only) and then
    removed. If it cannot be opened or removed, the attempt is retried.

    Args:
        retries: Maximum number of attempts to read the sync file.
        delay: Seconds to sleep between two consecutive attempts.

    Returns:
        "127.0.0.1" if the file starts with "ipv4", "::1" if it starts
        with "ipv6".

    Raises:
        ValueError: The file was read but holds neither "ipv4" nor "ipv6".
        TimeoutError: The file could not be read and removed within
            ``retries`` attempts.
    """
    for attempt in range(retries):
        # Sleep only *between* attempts; there is no point in waiting after
        # the last one has already failed.
        if attempt:
            time.sleep(delay)

        try:
            with open(SYNCFILE, "r") as f:
                addr = f.read(4)
            os.remove(SYNCFILE)
        except OSError:
            # File not there yet (or not removable): try again.
            continue

        if addr == "ipv4":
            return "127.0.0.1"

        if addr == "ipv6":
            return "::1"

        raise ValueError("Invalid address type")

    raise TimeoutError("Sync timeout: file ({}) not found".format(SYNCFILE))

if sys.argv[1] == "--post-start":
    # Create the control socket before forking, so it is already listening
    # when this hook invocation returns and the later hook invocations
    # connect to it.
    internal_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    internal_sock.bind(INTERNAL_SERVER)
    internal_sock.listen(1)

    # Wait for test server to be ready
    server_addr = wait_server_addr()

    pid = os.fork()
    if pid == 0:
        os.setsid()  # Detach from parent

        # Byte sent to the test server for each hook stage.
        payloads = {
            "--pre-dump": b"A",
            # Data sent while TCP socket is "locked" should be retransmitted.
            "--pre-restore": b"B",
            "--post-restore": b"C",
        }

        # The context managers close the test socket and every accepted
        # control connection on all exit paths, including a failing recv().
        with socket.create_connection((server_addr, PORT), TIMEOUT) as test_sock:
            while True:
                internal_conn, _ignore = internal_sock.accept()
                with internal_conn:
                    hook = internal_conn.recv(100).decode()

                    # "--clean" stops the helper; no reply is sent for it.
                    if hook == "--clean":
                        break

                    try:
                        # Unrecognised requests send nothing to the test
                        # server but are still acknowledged with "pass".
                        payload = payloads.get(hook)
                        if payload is not None:
                            test_sock.sendall(payload)

                        internal_conn.sendall(b"pass")
                    except Exception:
                        # Report the failure to the waiting hook, then let
                        # the helper die with the original error.
                        internal_conn.sendall(b"fail")
                        raise

    # Reached by the parent right after the fork and by the child once it
    # has been told to clean up.
    internal_sock.close()

if sys.argv[1] in ["--pre-dump", "--pre-restore", "--post-restore"]:
    # Forward the hook name to the background helper and wait for its
    # verdict. The context manager closes the control connection on both
    # the success and the failure path.
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as internal_conn:
        internal_conn.connect(INTERNAL_SERVER)
        internal_conn.sendall(sys.argv[1].encode())

        status = internal_conn.recv(100).decode()

    # Anything other than "pass" (including an empty reply) fails the hook.
    if status != "pass":
        sys.exit(1)

if sys.argv[1] == "--clean":
    # Clean up internal server: send it the "--clean" request. No reply is
    # read; the connection is closed as soon as the request is sent.
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as internal_conn:
        internal_conn.connect(INTERNAL_SERVER)
        internal_conn.sendall(sys.argv[1].encode())
